{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Pjmz-RORV8E"
      },
      "source": [
        "# Train a QA model\n",
        "\n",
        "The [Hugging Face Model Hub](https://huggingface.co/models) has a wide range of models that can handle many tasks. While these models perform well, the best performance is often found when fine-tuning a model with task-specific data. \n",
        "\n",
        "Hugging Face provides a [number of full-featured examples](https://github.com/huggingface/transformers/tree/master/examples) to assist with training task-specific models. When building models from the command line, these scripts are a great way to get started.\n",
        "\n",
        "txtai provides a training pipeline that can be used to train new models programatically using the Transformers Trainer framework.\n",
        "\n",
        "This example trains a small QA model and then further fine-tunes it with a couple new examples (few-shot learning)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk31rbYjSTYm"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XMQuuun2R06J"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline-train] datasets pandas"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r6nmtieHdMfr"
      },
      "source": [
        "# Train a SQuAD 2.0 Model\n",
        "\n",
        "The first step is training a SQuAD 2.0 model. SQuAD is a question-answer dataset that poses a question with a context along with the identified answer. It's also possible to not have an answer. See the [SQuAD dataset website](https://rajpurkar.github.io/SQuAD-explorer/) for more information.\n",
        "\n",
        "We'll use a tiny Bert model with a portion of SQuAD 2.0 for efficiency purposes."
      ]
    },
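    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the record layout, the cell below builds two hand-written records that follow the `squad_v2` schema in the `datasets` library: a `question`, a `context`, and an `answers` dictionary with `text` and `answer_start` lists. The records and the `has_answer` helper are made up for this sketch and are not taken from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Hand-written records in the SQuAD 2.0 layout (illustrative, not real dataset rows)\n",
        "answerable = {\n",
        "    \"question\": \"What ingredient?\",\n",
        "    \"context\": \"Dice 1 yellow onion\",\n",
        "    \"answers\": {\"text\": [\"onion\"], \"answer_start\": [14]},\n",
        "}\n",
        "\n",
        "# An unanswerable question has empty answer lists\n",
        "unanswerable = {\n",
        "    \"question\": \"What temperature?\",\n",
        "    \"context\": \"Dice 1 yellow onion\",\n",
        "    \"answers\": {\"text\": [], \"answer_start\": []},\n",
        "}\n",
        "\n",
        "def has_answer(record):\n",
        "    \"\"\"Hypothetical helper: True when the record has at least one answer span.\"\"\"\n",
        "    return len(record[\"answers\"][\"text\"]) > 0\n",
        "\n",
        "for record in (answerable, unanswerable):\n",
        "    print(record[\"question\"], \"->\", has_answer(record))"
      ],
      "execution_count": null,
      "outputs": []
    },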
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 297
        },
        "id": "pg9-tUxEdRfk",
        "outputId": "06195c45-4b39-46e5-a462-af566f437ade"
      },
      "source": [
        "from datasets import load_dataset\n",
        "from txtai.pipeline import HFTrainer\n",
        "\n",
        "ds = load_dataset(\"squad_v2\")\n",
        "\n",
        "trainer = HFTrainer()\n",
        "trainer(\"google/bert_uncased_L-2_H-128_A-2\", ds[\"train\"].select(range(3000)), task=\"question-answering\", output_dir=\"bert-tiny-squadv2\")\n",
        "print(\"Training complete\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Reusing dataset squad_v2 (/root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)\n",
            "Loading cached processed dataset at /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d/cache-73bbe029cf3366fc.arrow\n",
            "Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']\n",
            "- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1131' max='1131' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [1131/1131 00:50, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>500</td>\n",
              "      <td>4.501800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1000</td>\n",
              "      <td>3.875900</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training complete\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zZtHxNSwFNGC"
      },
      "source": [
        "# Fine-tune with new data\n",
        "\n",
        "Next we'll add a few additional examples. Fine-tuning a QA model will help with framing a certain type of question or improve performance for a specific use-case. \n",
        "\n",
        "For smaller models with a narrow use case, this helps the model zero in on the types of questions that are to be asked. In this case, we want to tell the model exactly the types of information we're looking for when asking for ingredients. This will help improve confidence in the answers the model is generating.\n"
      ]
    },
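    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training examples in this section give each answer as a plain string rather than a SQuAD-style dictionary. One way to relate the two forms is to look up the answer's character offset in the context. The `to_squad_answer` helper below is a hypothetical sketch of that conversion, written only for illustration. It is not part of txtai's API, and it assumes the answer text appears verbatim in the context."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "def to_squad_answer(context, answer):\n",
        "    \"\"\"Hypothetical helper: convert a plain-string answer to SQuAD-style fields.\"\"\"\n",
        "    if not answer:\n",
        "        # No answer: SQuAD 2.0 represents this with empty lists\n",
        "        return {\"text\": [], \"answer_start\": []}\n",
        "\n",
        "    # str.index raises ValueError if the answer is not found in the context\n",
        "    return {\"text\": [answer], \"answer_start\": [context.index(answer)]}\n",
        "\n",
        "print(to_squad_answer(\"Pour 1 can whole tomatoes\", \"tomatoes\"))"
      ],
      "execution_count": null,
      "outputs": []
    },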
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 75
        },
        "id": "JBeScS5dFNeW",
        "outputId": "13596524-b476-4f55-f430-d25eeda9301f"
      },
      "source": [
        "# Training data\n",
        "data = [\n",
        "    {\"question\": \"What ingredient?\", \"context\": \"Pour 1 can whole tomatoes\", \"answers\": \"tomatoes\"},\n",
        "    {\"question\": \"What ingredient?\", \"context\": \"Dice 1 yellow onion\", \"answers\": \"onion\"},\n",
        "    {\"question\": \"What ingredient?\", \"context\": \"Cut 1 red pepper\", \"answers\": \"pepper\"},\n",
        "    {\"question\": \"What ingredient?\", \"context\": \"Peel and dice 1 clove garlic\", \"answers\": \"garlic\"},\n",
        "    {\"question\": \"What ingredient?\", \"context\": \"Put 1/2 lb beef\", \"answers\": \"beef\"},\n",
        "]\n",
        "\n",
        "model, tokenizer = trainer(\"bert-tiny-squadv2\", data, task=\"question-answering\", num_train_epochs=10)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [10/10 00:00, Epoch 10/10]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V7nAl3WtkBNK"
      },
      "source": [
        "# Test the model\n",
        "\n",
        "Now we're ready to test the results! The following sections run a question against the original model only trained with SQuAD 2.0 and the further fine-tuned model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "46fMiJrAIBu4",
        "outputId": "fb92ca0a-d433-486f-b61c-054f4e4a9b36"
      },
      "source": [
        "from transformers import pipeline\n",
        "\n",
        "questions = pipeline(\"question-answering\", model=\"bert-tiny-squadv2\")\n",
        "questions(\"What ingredient?\", \"Peel and dice 1 shallot\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'answer': 'dice 1 shallot',\n",
              " 'end': 23,\n",
              " 'score': 0.05128436163067818,\n",
              " 'start': 9}"
            ]
          },
          "metadata": {},
          "execution_count": 57
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nWQMRQm0NwdN",
        "outputId": "a1f15b5c-1cf8-4fe5-daa8-e19adc1700e1"
      },
      "source": [
        "from transformers import pipeline\n",
        "\n",
        "questions = pipeline(\"question-answering\", model=model.to(\"cpu\"), tokenizer=tokenizer)\n",
        "questions(\"What ingredient?\", \"Peel and dice 1 shallot\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'answer': 'shallot', 'end': 23, 'score': 0.13187439739704132, 'start': 16}"
            ]
          },
          "metadata": {},
          "execution_count": 58
        }
      ]
    },
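    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To put the two results side by side, the cell below copies the outputs recorded above into plain dictionaries and compares them. The `summarize` helper is a hypothetical name introduced for this sketch, and the numbers will differ if the models are retrained."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Outputs copied from the two cells above\n",
        "base = {\"answer\": \"dice 1 shallot\", \"score\": 0.05128436163067818}\n",
        "tuned = {\"answer\": \"shallot\", \"score\": 0.13187439739704132}\n",
        "\n",
        "def summarize(before, after):\n",
        "    \"\"\"Hypothetical helper: report the change in answer and score.\"\"\"\n",
        "    ratio = after[\"score\"] / before[\"score\"]\n",
        "    return f\"{before['answer']!r} -> {after['answer']!r}, score x{ratio:.1f}\"\n",
        "\n",
        "print(summarize(base, tuned))"
      ],
      "execution_count": null,
      "outputs": []
    },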
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJoYksLbkZTJ"
      },
      "source": [
        "See how the results are more confident and have a better answer. This method allows using a smaller model with a narrow set of functionality with the upside of increased speed. Give it a try with your own data!"
      ]
    }
  ]
}
